import asyncio
import hashlib
import mimetypes
import os
import re
import shutil
import tempfile
import zipfile
from pathlib import Path
from urllib.parse import parse_qsl, unquote, urlparse

import aiohttp
import structlog
from multidict import CIMultiDictProxy
from yarl import URL

from skyvern.config import settings
from skyvern.constants import BROWSER_DOWNLOAD_TIMEOUT, BROWSER_DOWNLOADING_SUFFIX, REPO_ROOT_DIR
from skyvern.exceptions import DownloadFileMaxSizeExceeded, DownloadFileMaxWaitingTime
from skyvern.forge.sdk.api.aws import AsyncAWSClient, aws_client
from skyvern.utils.url_validators import encode_url

LOG = structlog.get_logger()


async def download_from_s3(client: AsyncAWSClient, s3_uri: str) -> str:
    """Download an S3 object into a local temp file and return the file path.

    Args:
        client: AWS client used to fetch the object bytes.
        s3_uri: Full S3 URI; its final path segment becomes the file name.

    Returns:
        Path of the downloaded file on the local file system.
    """
    downloaded_bytes = await client.download_file(uri=s3_uri)
    filename = s3_uri.split("/")[-1]  # Extract filename from the end of S3 URI
    file_path = create_named_temporary_file(delete=False, file_name=filename)
    LOG.info(f"Downloaded file to {file_path.name}")
    # Close the handle after writing so buffered bytes are flushed to disk
    # before the path is handed back (the file survives: delete=False).
    try:
        file_path.write(downloaded_bytes)
    finally:
        file_path.close()
    return file_path.name


def get_file_name_and_suffix_from_headers(headers: CIMultiDictProxy[str]) -> tuple[str, str]:
    """Derive a file stem and suffix from HTTP response headers.

    Prefers the quoted filename in Content-Disposition; if that yields no
    suffix, falls back to guessing an extension from Content-Type.
    """
    stem = ""
    suffix: str | None = ""

    # First choice: the quoted filename inside Content-Disposition.
    disposition = headers.get("Content-Disposition")
    if disposition:
        matches = re.findall('filename="(.+)"', disposition, re.IGNORECASE)
        if matches:
            candidate = Path(matches[0])
            stem, suffix = candidate.stem, candidate.suffix

    if suffix:
        return stem, suffix

    # Fallback: map the Content-Type MIME value to a conventional extension.
    content_type = headers.get("Content-Type")
    if content_type:
        guessed = mimetypes.guess_extension(content_type)
        if guessed:
            return stem, guessed

    return stem, suffix or ""


def extract_google_drive_file_id(url: str) -> str | None:
    """Extract file ID from Google Drive URL."""
    # Handle format: https://drive.google.com/file/d/{file_id}/view
    match = re.search(r"/file/d/([a-zA-Z0-9_-]+)", url)
    if match:
        return match.group(1)
    return None


def is_valid_mime_type(file_path: str) -> bool:
    """True when the path's extension maps to a known MIME type."""
    guessed_type = mimetypes.guess_type(file_path)[0]
    return guessed_type is not None


def validate_download_url(url: str) -> bool:
    """Validate if a URL is supported for downloading.

    Security validation for URL downloads to prevent:
    - File system access outside allowed directories
    - Access to local file system in non-local environments
    - Unsupported or dangerous URL schemes

    Args:
        url: The URL to validate

    Returns:
        True if valid, False otherwise.
    """
    try:
        parsed_url = urlparse(url)
        scheme = parsed_url.scheme.lower()

        # Allow http/https URLs (includes Google Drive which uses https)
        if scheme in ("http", "https"):
            return True

        # Allow S3 URIs for Skyvern uploads bucket
        if scheme == "s3":
            return url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_")

        # Allow file:// URLs only in local environment
        if scheme == "file":
            if settings.ENV != "local":
                return False

            try:
                file_path = parse_uri_to_path(url)
            except ValueError:
                return False

            # Resolve symlinks and ".." segments before comparing, and require
            # a path-separator boundary so sibling directories such as
            # "<root>/downloads_evil" and traversal like
            # "<root>/downloads/../secret" are rejected.
            allowed_dir = os.path.realpath(f"{REPO_ROOT_DIR}/downloads")
            resolved_path = os.path.realpath(file_path)
            return resolved_path == allowed_dir or resolved_path.startswith(allowed_dir + os.sep)

        # Reject unsupported schemes
        return False

    except Exception:
        return False


async def download_file(url: str, max_size_mb: int | None = None) -> str:
    """Download the resource at ``url`` to local disk and return its path.

    Supports http(s) URLs (Google Drive share links are rewritten to direct
    download links), s3:// URIs for the Skyvern uploads bucket, and — in the
    local environment only — file:// URIs inside the downloads directory.

    Args:
        url: The source URL.
        max_size_mb: Optional size cap; downloads above it are aborted.

    Returns:
        Path of the downloaded file on the local file system.

    Raises:
        DownloadFileMaxSizeExceeded: When the payload exceeds max_size_mb.
        aiohttp.ClientResponseError: On non-success HTTP status codes.
    """
    try:
        # Check if URL is a Google Drive link
        if "drive.google.com" in url:
            file_id = extract_google_drive_file_id(url)
            if file_id:
                # Convert to direct download URL
                url = f"https://drive.google.com/uc?export=download&id={file_id}"
                LOG.info("Converting Google Drive link to direct download", url=url)

        # Check if URL is an S3 URI
        if url.startswith(f"s3://{settings.AWS_S3_BUCKET_UPLOADS}/{settings.ENV}/o_"):
            LOG.info("Downloading Skyvern file from S3", url=url)
            client = AsyncAWSClient()
            return await download_from_s3(client, url)

        # Check if URL is a file:// URI
        # we only support to download local files when the environment is local
        # and the file is in the skyvern downloads directory
        if url.startswith("file://") and settings.ENV == "local":
            file_path = parse_uri_to_path(url)
            # Resolve ".." segments and symlinks, and require a separator
            # boundary so sibling directories (e.g. "downloads_evil") and
            # traversal out of the downloads directory are rejected.
            allowed_dir = os.path.realpath(f"{REPO_ROOT_DIR}/downloads")
            resolved_path = os.path.realpath(file_path)
            if resolved_path == allowed_dir or resolved_path.startswith(allowed_dir + os.sep):
                LOG.info("Downloading file from local file system", url=url)
                return file_path

        # Compute the byte limit once instead of on every chunk.
        max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None

        async with aiohttp.ClientSession(raise_for_status=True) as session:
            LOG.info("Starting to download file", url=url)
            encoded_url = encode_url(url)
            async with session.get(URL(encoded_url, encoded=True)) as response:
                # Reject early when the server declares an oversized payload.
                if max_size_bytes and response.content_length and response.content_length > max_size_bytes:
                    raise DownloadFileMaxSizeExceeded(max_size_mb)

                # Parse the URL so its query string / path can name the file.
                parsed_url = urlparse(url)

                temp_dir = make_temp_directory(prefix="skyvern_downloads_")

                file_name = ""
                file_suffix = ""
                try:
                    file_name, file_suffix = get_file_name_and_suffix_from_headers(response.headers)
                    if not file_suffix:
                        LOG.warning("No extension name retrieved from HTTP headers")
                except Exception:
                    LOG.exception("Failed to retrieve the file extension from HTTP headers")

                # parse the query params to get the file name
                query_params = dict(parse_qsl(parsed_url.query))
                if "download" in query_params:
                    file_name = query_params["download"]

                if not file_name:
                    LOG.info("No file name retrieved from HTTP headers, using the file name from the URL")
                    file_name = os.path.basename(parsed_url.path)

                if not is_valid_mime_type(file_name) and file_suffix:
                    LOG.info("No file extension detected, adding the extension from HTTP headers")
                    file_name = file_name + file_suffix

                file_name = sanitize_filename(file_name)
                file_path = os.path.join(temp_dir, file_name)

                LOG.info(f"Downloading file to {file_path}")
                with open(file_path, "wb") as f:
                    # Stream the body to disk, enforcing the size cap as we go.
                    total_bytes_downloaded = 0
                    async for chunk in response.content.iter_chunked(1024):
                        f.write(chunk)
                        total_bytes_downloaded += len(chunk)
                        if max_size_bytes and total_bytes_downloaded > max_size_bytes:
                            raise DownloadFileMaxSizeExceeded(max_size_mb)

                LOG.info(f"File downloaded successfully to {file_path}")
                return file_path
    except aiohttp.ClientResponseError as e:
        LOG.error(f"Failed to download file, status code: {e.status}")
        raise
    except DownloadFileMaxSizeExceeded as e:
        LOG.error(f"Failed to download file, max size exceeded: {e.max_size}")
        raise
    except Exception:
        LOG.exception("Failed to download file")
        raise


def zip_files(files_path: str, zip_file_path: str) -> str:
    """Recursively zip everything under files_path into zip_file_path.

    Entries are stored with paths relative to files_path so the archive
    reproduces the directory layout. Returns zip_file_path.
    """
    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for current_root, _dirs, file_names in os.walk(files_path):
            for file_name in file_names:
                absolute_path = os.path.join(current_root, file_name)
                relative_path = os.path.relpath(absolute_path, files_path)
                archive.write(absolute_path, relative_path)

    return zip_file_path


def unzip_files(zip_file_path: str, output_dir: str) -> None:
    """Extract every member of the zip archive at zip_file_path into output_dir."""
    with zipfile.ZipFile(zip_file_path, "r") as archive:
        archive.extractall(output_dir)


def get_path_for_workflow_download_directory(run_id: str | None) -> Path:
    """Return the per-run download directory (created if missing) as a Path."""
    download_dir = get_download_dir(run_id=run_id)
    return Path(download_dir)


def get_download_dir(run_id: str | None) -> str:
    """Return the downloads directory for a run, creating it if needed.

    NOTE(review): a None run_id yields a literal ".../downloads/None"
    directory — confirm callers always pass a real run id.
    """
    path = f"{REPO_ROOT_DIR}/downloads/{run_id}"
    os.makedirs(path, exist_ok=True)
    return path


def list_files_in_directory(directory: Path, recursive: bool = False) -> list[str]:
    """List file paths under directory; only the top level unless recursive."""
    collected: list[str] = []
    for current_root, _dirs, file_names in os.walk(directory):
        for file_name in file_names:
            collected.append(os.path.join(current_root, file_name))
        if not recursive:
            # os.walk yields the top directory first; stop before descending.
            break
    return collected


def list_downloading_files_in_directory(
    directory: Path, downloading_suffix: str = BROWSER_DOWNLOADING_SUFFIX
) -> list[str]:
    """Return top-level files in directory whose suffix marks an in-progress download."""
    return [file for file in list_files_in_directory(directory) if Path(file).suffix == downloading_suffix]


async def wait_for_download_finished(downloading_files: list[str], timeout: float = BROWSER_DOWNLOAD_TIMEOUT) -> None:
    """Poll until every in-progress download marker file has disappeared.

    Each entry in downloading_files is the path of a "still downloading"
    marker (an s3:// URI or a local file path). A download counts as
    finished once its marker no longer exists; the remaining set is
    re-checked once per second.

    Raises:
        DownloadFileMaxWaitingTime: If markers remain after `timeout`
            seconds; carries the paths still pending at that moment.
    """
    cur_downloading_files = downloading_files
    try:
        async with asyncio.timeout(timeout):
            while len(cur_downloading_files) > 0:
                new_downloading_files: list[str] = []
                for path in cur_downloading_files:
                    if path.startswith("s3://"):
                        try:
                            await aws_client.get_object_info(path)
                        except Exception:
                            # A failed lookup is treated as "marker gone", i.e.
                            # the download completed and the marker was removed.
                            LOG.debug(
                                "downloading file is not found in s3, means the file finished downloading", path=path
                            )
                            continue
                    else:
                        if not Path(path).exists():
                            LOG.debug(
                                "downloading file is not found in the local file system, means the file finished downloading",
                                path=path,
                            )
                            continue
                    # Marker still present: keep polling this file next round.
                    new_downloading_files.append(path)
                cur_downloading_files = new_downloading_files
                await asyncio.sleep(1)
    except asyncio.TimeoutError:
        raise DownloadFileMaxWaitingTime(downloading_files=cur_downloading_files)


def get_number_of_files_in_directory(directory: Path, recursive: bool = False) -> int:
    """Count files in directory (top level only unless recursive)."""
    files = list_files_in_directory(directory, recursive)
    return len(files)


def sanitize_filename(filename: str) -> str:
    """Strip characters that are unsafe in a file name.

    Keeps alphanumerics plus "-", "_", ".", "%" and spaces. If nothing usable
    survives — including names made only of dots/spaces such as "." or "..",
    which would resolve to the target directory or its parent when joined to
    a path — a fixed fallback name is returned instead.
    """
    allowed_punctuation = {"-", "_", ".", "%", " "}
    sanitized = "".join(c for c in filename if c.isalnum() or c in allowed_punctuation)
    # Guard against "", ".", ".." and similar dot/space-only results, which are
    # directory references rather than file names.
    if not sanitized.strip(". "):
        return "download"
    return sanitized


def rename_file(file_path: str, new_file_name: str) -> str:
    """Rename file_path to a sanitized new_file_name in the same directory.

    Best-effort: on any failure the error is logged and the original path
    is returned unchanged.
    """
    try:
        new_file_name = sanitize_filename(new_file_name)
        target_path = os.path.join(os.path.dirname(file_path), new_file_name)
        os.rename(file_path, target_path)
        return target_path
    except Exception:
        LOG.exception(f"Failed to rename file {file_path} to {new_file_name}")
        return file_path


def calculate_sha256_for_file(file_path: str) -> str:
    """Return the hex-encoded SHA256 digest of the file at file_path."""
    digest = hashlib.sha256()
    with open(file_path, "rb") as f:
        # Stream in fixed-size chunks so large files never load fully into memory.
        while chunk := f.read(4096):
            digest.update(chunk)
    return digest.hexdigest()


def create_folder_if_not_exist(dir: str) -> None:
    """Create dir (and any missing parents); do nothing if it already exists."""
    Path(dir).mkdir(parents=True, exist_ok=True)


def get_skyvern_temp_dir() -> str:
    """Return the configured temp directory, creating it on first use."""
    configured_dir = settings.TEMP_PATH
    create_folder_if_not_exist(configured_dir)
    return configured_dir


def make_temp_directory(
    suffix: str | None = None,
    prefix: str | None = None,
) -> str:
    """Create a fresh uniquely-named directory under the configured temp path."""
    base_dir = settings.TEMP_PATH
    create_folder_if_not_exist(base_dir)
    return tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=base_dir)


def create_named_temporary_file(delete: bool = True, file_name: str | None = None) -> tempfile._TemporaryFileWrapper:
    """Create a temp file under the configured temp dir, opened for binary writes.

    Args:
        delete: Whether the file is removed when the returned wrapper is closed.
        file_name: Optional exact name for the file (sanitized first). When
            omitted, NamedTemporaryFile generates a random unique name.

    Returns:
        A NamedTemporaryFile-like wrapper exposing .name and file methods.
    """
    temp_dir = settings.TEMP_PATH
    create_folder_if_not_exist(temp_dir)

    if file_name:
        # Sanitize the filename to remove any dangerous characters
        safe_file_name = sanitize_filename(file_name)
        # Create file with exact name (without random characters).
        # An existing file of the same name is truncated by the "wb" open.
        file_path = os.path.join(temp_dir, safe_file_name)
        # Open in binary mode and return a NamedTemporaryFile-like object.
        # NOTE(review): tempfile._TemporaryFileWrapper is a private API whose
        # signature may change between Python versions — confirm on upgrades.
        file = open(file_path, "wb")
        return tempfile._TemporaryFileWrapper(file, file_path, delete=delete)

    return tempfile.NamedTemporaryFile(dir=temp_dir, delete=delete)


def clean_up_dir(dir: str) -> None:
    """Empty out dir without removing dir itself; delete dir if it is a file.

    Missing paths are ignored. Top-level files and symlinks are unlinked;
    subdirectories are removed recursively.
    """
    if not os.path.exists(dir):
        return

    # A plain file passed in is simply deleted.
    if os.path.isfile(dir):
        os.unlink(dir)
        return

    for entry in os.listdir(dir):
        entry_path = os.path.join(dir, entry)
        if os.path.isfile(entry_path) or os.path.islink(entry_path):
            os.unlink(entry_path)
        elif os.path.isdir(entry_path):
            shutil.rmtree(entry_path)


def clean_up_skyvern_temp_dir() -> None:
    """Remove everything inside the Skyvern temp directory."""
    clean_up_dir(get_skyvern_temp_dir())


def parse_uri_to_path(uri: str) -> str:
    """Convert a file:// URI into a local file-system path.

    Raises:
        ValueError: If the URI scheme is anything other than "file".
    """
    parsed = urlparse(uri)
    if parsed.scheme != "file":
        raise ValueError(f"Invalid URI scheme: {parsed.scheme} expected: file")
    # netloc is normally empty for file:/// URIs; keep it for host-qualified forms.
    return unquote(parsed.netloc + parsed.path)
